import numpy as np
from numpy.testing import assert_equal

from menpofit.transform import DifferentiableAffine


jac_solution2d = np.array(
    [[[0., 0.],
      [0., 0.],
      [0., 0.],
      [0., 0.],
      [1., 0.],
      [0., 1.]],
     [[0., 0.],
      [0., 0.],
      [1., 0.],
      [0., 1.],
      [1., 0.],
      [0., 1.]],
     [[0., 0.],
      [0., 0.],
      [2., 0.],
      [0., 2.],
      [1., 0.],
      [0., 1.]],
     [[1., 0.],
      [0., 1.],
      [0., 0.],
      [0., 0.],
      [1., 0.],
      [0., 1.]],
     [[1., 0.],
      [0., 1.],
      [1., 0.],
      [0., 1.],
      [1., 0.],
      [0., 1.]],
     [[1., 0.],
      [0., 1.],
      [2., 0.],
      [0., 2.],
      [1., 0.],
      [0., 1.]]])

jac_solution3d = np.array(
    [[[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [2., 0., 0.],
      [0., 2., 0.],
      [0., 0., 2.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [2., 0., 0.],
      [0., 2., 0.],
      [0., 0., 2.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [2., 0., 0.],
      [0., 2., 0.],
      [0., 0., 2.],
      [0., 0., 0.],
      [0., 0., 0.],
      [0., 0., 0.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]],
     [[1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [2., 0., 0.],
      [0., 2., 0.],
      [0., 0., 2.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.],
      [1., 0., 0.],
      [0., 1., 0.],
      [0., 0., 1.]]])


def test_affine_jacobian_2d_with_positions():
    params = np.array([0, 0.1, 0.2, 0, 30, 70])
    t = DifferentiableAffine.init_identity(2).from_vector(params)
    explicit_pixel_locations = np.array(
        [[0, 0],
         [0, 1],
         [0, 2],
         [1, 0],
         [1, 1],
         [1, 2]])
    dW_dp = t.d_dp(explicit_pixel_locations)
    assert_equal(dW_dp, jac_solution2d)


def test_affine_jacobian_3d_with_positions():
    params = np.ones(12)
    t = DifferentiableAffine.init_identity(3).from_vector(params)
    explicit_pixel_locations = np.array(
        [[0, 0, 0],
         [0, 0, 1],
         [0, 1, 0],
         [0, 1, 1],
         [0, 2, 0],
         [0, 2, 1],
         [1, 0, 0],
         [1, 0, 1],
         [1, 1, 0],
         [1, 1, 1],
         [1, 2, 0],
         [1, 2, 1]])
    dW_dp = t.d_dp(explicit_pixel_locations)
    assert_equal(dW_dp, jac_solution3d)
